A Little Idea about Middleware
在编写服务的 server 或者 client 时,我们都会用到中间件( middleware )。 中间件极高的复用性以及可组合性极大地拓展了它的用武之地。 本文旨在介绍一种采用 CPS 风格的 middleware 写法。
一般来说, middleware 出于和 handler 可组合性的考量都会将类型设计得尽可能贴近 handler 。
一个极端是 github.com/gin-gonic/gin , middleware 直接采用 handler 的类型。
比如 "github.com/gin-gonic/gin".(*Engine).Use
的定义:
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes
更一般的, middleware 会将类型设计为类似 func(nextHandler Handler) Handler
的形式。
举几个例子:
// package: "github.com/go-chi/chi/middleware" func Logger(next http.Handler) http.Handler // package: "github.com/gorilla/handlers" func CompressHandler(h http.Handler) http.Handler // package: "google.golang.org/grpc" type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
前两个都很标准。 最后 gRPC 的做法虽然不是那么标准,但是考虑到 middleware (被称为 interceptor )返回了 handler 的返回值,并且也额外接受了 handler 的参数,进行一个类似科里化的变化就能变成那个形式了。
见识到了 func(nextHandler Handler) Handler
这种形式的普遍性后,下面说说这种形式的局限性。
由于对 nextHandler
的调用完全发生在 middleware 当中,所以之后 middleware 发生的 abort 难于让当前 middleware 感受到。
像 gin 一样自行封装 "github.com/gin-gonic/gin".Context
类型并直接提供 "github.com/gin-gonic/gin".(*Context).IsAborted
当然是一种可行的做法——但这要求要么 response 原生就提供 aborted 判断,要么得花额外的工作量来封装新类型。
其实这种局限只是控制流上的局限性,所以可以通过引入 CPS 风格来解决。
首先我们给出 middleware 的类型(注意到这里 Middleware
类型是递归的,想要单独类型递归可以看笔者旧文 Type Recursion in Golang 了解):
type Middleware func(req Request, resp *Response, nextMiddleware Middleware) func (mw Middleware) Pass(req Request, resp *Response, nextMiddleware Middleware) { if mw != nil { mw(req, resp, nextMiddleware) } }
这里之所以定义了一个 Middleware.Pass
方法来代替直接调用是因为这样就可以直接用 nil
来表示空 middleware 也就是零元。
这个 Middleware
类型和我们刚才介绍的常见的形式 func(nextHandler Handler) Handler
最大的区别其实就在于将 nextHandler Handler
参数替换成了 nextMiddleware Middleware
。
这样整个 middleware 系统就“自成一体”,(暂时地)抛开了 handler ,也就抛开了 handler 所受的一些限制。
在这个 middleware 系统中, middleware 写法的模板是这样的:
func DemoMiddleware(req Request, resp *Response, next Middleware) { DoSomePreJobs() if shouldNotProcessReq { *resp = someResp return } next.Pass(req, resp, CastMagic) DoSomePostJobs() }
DoSomePreJobs
和 DoSomePostJobs
都很好理解, CastMagic
有什么蹊跷之处么?
当然有, CastMagic
就是用来控制代码执行的控制流的“注入点”。
这么说有点尬吹了,不妨先看两个例子:
func AbortWithFn(f func()) Middleware { return func(req Request, resp *Response, next Middleware) { f() } } func LiftFn(f func()) Middleware { return func(req Request, resp *Response, next Middleware) { f() next.Pass(req, resp, nil) } }
在实际使用中,可以用 AbortWithFn
和 LiftFn
构造的 Middleware
来填充 CastMagic
的值,从而达到 abort 的效果或是在确保未进行 abort 的情况下观察 response 。
不妨先看一个简单的 middleware Logger
:
func Logger(req Request, resp *Response, next Middleware) { fmt.Println("Start") defer fmt.Println("End") next.Pass(req, resp, LiftFn(func() { fmt.Printf("Got resp: %q\n", *resp) })) }
实际上运行的时候,这些 middleware 一个接着一个靠 next
给串起来,而 next.Pass(req, resp, LiftFn(...))
相当于是动态地给当前的 middleware 串又 append 一个新的。
而所有后续的要执行的 middleware 都会包含在靠前的 middleware 对应的 nextMiddleware
里,所以一旦靠前的 middleware 不执行 nextMiddleware
,所有后面的 middleware 就都不会执行了,这也就是 AbortWithFn
的做法。
这里可能有点绕,不过后面的示例代码应该有比较明晰的展示效果。
在展示示例代码之前,还有两件事要做:一是将“(暂时)抛开的” handler 接回来,毕竟前文所述 middleware 已经自成一体了,对于 handler 需要做点工作才能融入这个体系;二是 middleware 的组合如何进行,实际上也是需要点技巧的。
方便起见,我们就挑 func(Request, *Response)
作为 Handler
的定义,并给出相应的用于将 Handler
提升为 Middleware
的函数 LiftHandler
:
type Handler func(Request, *Response) func LiftHandler(h Handler) Middleware { return func(req Request, resp *Response, next Middleware) { h(req, resp) next.Pass(req, resp, nil) } }
middleware 的组合当然是从右向左依次进行,不过最基本的组合两个 middleware 的操作会因为采用 CPS 风格而复杂一些。
先假装已经有了组合两个 middleware 的函数 Compose
,看如何拼装多个:
func Chain(mws ...Middleware) func(Handler) Handler { r := Middleware(nil) for _, mw := range mws { r = Compose(r, mw) } return func(h Handler) Handler { return func(req Request, resp *Response) { r.Pass(req, resp, LiftHandler(h)) } } }
接下来尝试给出 Compose
的实现。
首先回想一下 middleware 和 nextMiddleware
的关系: middleware 将需要交由后续 middleware 托管的代码逻辑通过作为 nextMiddleware
的 nextMiddleware
传递给 nextMiddleware.Pass()
来往后传递。
也就是说对于本层 middleware 而言, nextMiddleware
就是下一个 middleware ,而自己传给 nextMiddleware
的 nextMiddleware
却是 append 到目前这个 middleware 串的最末。
所以对于两个 middleware mw1
和 mw2
, Compose(mw1, mw2)
这个中间件需要对它的 nextMiddleware
(不妨称为 mw3
)传入原本 mw1
和 mw2
最后要做的事(不妨称为 next1
和 next2
),并且顺序是是 next2
再 next1
。
还是有点绕,不过结合代码多看看应该就理解了:
func Compose(mw1 Middleware, mw2 Middleware) Middleware { switch { case mw1 == nil: return mw2 case mw2 == nil: return mw1 default: return func(req Request, resp *Response, mw3 Middleware) { mw1.Pass(req, resp, func(req Request, resp *Response, next1 Middleware) { mw2.Pass(req, resp, func(req Request, resp *Response, next2 Middleware) { mw3.Pass(req, resp, Compose(next2, next1)) }) }) } } }
行文至此,这个 CPS 风格的 middleware 系统基本上就介绍完毕了。 最后上示例来看看效果:
type ( Request = int Response = string ) func main() { handler := func(req Request, resp *Response) { fmt.Println("Handler") *resp = strconv.FormatInt(int64(req), 10) } mws := make([]Middleware, 3) for i := range mws { i := i mws[i] = func(req Request, resp *Response, next Middleware) { fmt.Printf("MW%d start...\n", i) next.Pass(req, resp, LiftFn(func() { *resp = fmt.Sprintf("returned from MW%d: [%s]", i, *resp) fmt.Printf("MW%d end\n", i) })) } } abortAtRespMws := make([]Middleware, len(mws)) for i := range abortAtRespMws { i := i abortAtRespMws[i] = func(req Request, resp *Response, next Middleware) { fmt.Printf("MW%d start...\n", i) next.Pass(req, resp, AbortWithFn(func() { fmt.Printf("MW%d abort\n", i) })) } } // print out: // Start // MW0 start... // MW1 start... // MW2 start... // Handler // MW2 end // MW1 end // MW0 end // Got resp: "returned from MW0: [returned from MW1: [returned from MW2: [1]]]" // End Chain(Logger, mws[0], mws[1], mws[2])(handler)(1, new(string)) // print out: // Start // MW0 start... // MW1 start... // MW2 start... // Handler // MW2 end // MW1 end // MW0 abort // End Chain(Logger, abortAtRespMws[0], mws[1], mws[2])(handler)(2, new(string)) // print out: // Start // MW0 start... // MW1 start... // MW2 start... // Handler // MW2 end // MW1 abort // End Chain(Logger, mws[0], abortAtRespMws[1], mws[2])(handler)(3, new(string)) // print out: // Start // MW0 start... // MW1 start... // MW2 start... // Handler // MW2 abort // End Chain(Logger, mws[0], mws[1], abortAtRespMws[2])(handler)(4, new(string)) }